{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9C18T_ESG5Qx"
      },
      "source": [
        "##### Copyright 2020 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "prnvFmh4G6_H"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RMb8qzxEHFwS"
      },
      "source": [
        "# Neural Voxel Renderer\n",
        "\n",
        "\n",
        "![](https://storage.googleapis.com/tensorflow-graphics/notebooks/neural_voxel_renderer/teaser2-01.png)\n",
        "\n",
        "This notebook illustrates how to train [Neural Voxel Renderer](https://arxiv.org/abs/1912.04591) (CVPR2020) in Tensorflow 2.\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/projects/neural_voxel_renderer/train.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/projects/neural_voxel_renderer/train.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fkfUjFmxHnRz"
      },
      "source": [
        "## Setup and imports\n",
        "\n",
        "If Tensorflow Graphics is not installed on your system, the following cell can install the Tensorflow Graphics package for you.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QAvt9i8RHroc"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow_graphics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vDWSNfpF3sd4"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from tensorflow_graphics.projects.neural_voxel_renderer import helpers\n",
        "from tensorflow_graphics.projects.neural_voxel_renderer import models\n",
        "\n",
        "import datetime\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import re\n",
        "import time\n",
        "\n",
        "VOXEL_SIZE = (128, 128, 128, 4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5ok8IKzUIWCG"
      },
      "source": [
        "## Dataset loading\n",
        "\n",
        "We store our data in TFRecords with custom protobuf messages. Each training element contains the input voxels, the voxel rendering, the light position and the target image. The data is preprocessed (eg the colored voxels have been rendered and placed accordingly). See [this](https://) colab on how to generate the training/testing TFRecords.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gZG0Vn7M33Vl"
      },
      "outputs": [],
      "source": [
        "# Functions for dataset generation from a set of TFRecords.\n",
        "decode_proto = tf.compat.v1.io.decode_proto\n",
        "\n",
        "\n",
        "def tf_image_normalize(image):\n",
        "  \"\"\"Normalizes the image [-1, 1].\"\"\"\n",
        "  return (2 * tf.cast(image, tf.float32) / 255.) - 1\n",
        "\n",
        "\n",
        "def neural_voxel_plus_proto_get(element):\n",
        "  \"\"\"Extracts the contents from a VoxelSample proto to tensors.\"\"\"\n",
        "  _, values = decode_proto(element,\n",
        "                           \"giotto_blender.NeuralVoxelPlusSample\",\n",
        "                           [\"name\",\n",
        "                            \"voxel_data\",\n",
        "                            \"rerendering_data\",\n",
        "                            \"image_data\",\n",
        "                            \"light_position\"],\n",
        "                           [tf.string,\n",
        "                            tf.string,\n",
        "                            tf.string,\n",
        "                            tf.string,\n",
        "                            tf.float32])\n",
        "  filename = tf.squeeze(values[0])\n",
        "  voxel_data = tf.squeeze(values[1])\n",
        "  rerendering_data = tf.squeeze(values[2])\n",
        "  image_data = tf.squeeze(values[3])\n",
        "  light_position = values[4]\n",
        "  voxels = tf.io.decode_raw(voxel_data, out_type=tf.uint8)\n",
        "  voxels = tf.cast(tf.reshape(voxels, VOXEL_SIZE), tf.float32) / 255.0\n",
        "  rerendering = tf.cast(tf.image.decode_image(rerendering_data, channels=3),\n",
        "                        tf.float32)\n",
        "  rerendering = tf_image_normalize(rerendering)\n",
        "  image = tf.cast(tf.image.decode_image(image_data, channels=3), tf.float32)\n",
        "  image = tf_image_normalize(image)\n",
        "  return filename, voxels, rerendering, image, light_position\n",
        "\n",
        "\n",
        "def _expand_tfrecords_pattern(tfr_pattern):\n",
        "  \"\"\"Helper function to expand a tfrecord patter\"\"\"\n",
        "  def format_shards(m):\n",
        "    return '{}-?????-of-{:0\u003e5}{}'.format(*m.groups())\n",
        "  tfr_pattern = re.sub(r'^([^@]+)@(\\d+)([^@]+)$', format_shards, tfr_pattern)\n",
        "  return tfr_pattern\n",
        "\n",
        "\n",
        "def tfrecords_to_dataset(tfrecords_pattern,\n",
        "                         mapping_func,\n",
        "                         batch_size,\n",
        "                         buffer_size=5000):\n",
        "  \"\"\"Generates a TF Dataset from a rio pattern.\"\"\"\n",
        "  with tf.name_scope('Input/'):\n",
        "    tfrecords_pattern = _expand_tfrecords_pattern(tfrecords_pattern)\n",
        "    dataset = tf.data.Dataset.list_files(tfrecords_pattern, shuffle=True)\n",
        "    dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=16)\n",
        "    dataset = dataset.shuffle(buffer_size=buffer_size)\n",
        "    dataset = dataset.map(mapping_func)\n",
        "    dataset = dataset.batch(batch_size)\n",
        "    return dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DMwsbJ6RL-Wc"
      },
      "outputs": [],
      "source": [
        "# Download the example data, licensed under the Apache License, Version 2.0\n",
        "!rm -rf /tmp/tfrecords_dir/\n",
        "!mkdir /tmp/tfrecords_dir/\n",
        "!wget -P /tmp/tfrecords_dir/ https://storage.googleapis.com/tensorflow-graphics/notebooks/neural_voxel_renderer/train-00012-of-00100.tfrecord"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vE7eors94Dq9"
      },
      "outputs": [],
      "source": [
        "tfrecords_dir = '/tmp/tfrecords_dir/'\n",
        "tfrecords_pattern = os.path.join(tfrecords_dir, 'train@100.tfrecord')\n",
        "\n",
        "batch_size = 5\n",
        "mapping_function = neural_voxel_plus_proto_get\n",
        "dataset = tfrecords_to_dataset(tfrecords_pattern, mapping_function, batch_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "iGlgx3ly5eLb"
      },
      "outputs": [],
      "source": [
        "# Visualize some examples\n",
        "_, ax = plt.subplots(1, 4, figsize=(10, 10))\n",
        "i = 0\n",
        "for a in dataset.take(4):\n",
        "  (filename,\n",
        "   voxels,\n",
        "   vox_render,\n",
        "   target,\n",
        "   light_position) = a\n",
        "  ax[i].imshow(target[0]*0.5+0.5)\n",
        "  ax[i].axis('off')\n",
        "  i += 1\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "X_WgOOhjPayp"
      },
      "source": [
        "## Train the model\n",
        "\n",
        "NVR+ is trained with Adam optimizer and L1 and perceptual VGG loss. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vAjbsYCQs8n8"
      },
      "outputs": [],
      "source": [
        "# ==============================================================================\n",
        "# Defining model and optimizer\n",
        "LEARNING_RATE = 0.002\n",
        "\n",
        "nvr_plus_model = models.neural_voxel_renderer_plus_tf2()\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)\n",
        "\n",
        "# Saving and logging directories\n",
        "checkpoint_dir = '/tmp/checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
        "checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=nvr_plus_model)\n",
        "log_dir=\"/tmp/logs/\"\n",
        "summary_writer = tf.summary.create_file_writer(\n",
        "  log_dir + \"fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
        "\n",
        "# ==============================================================================\n",
        "# VGG loss\n",
        "VGG_LOSS_LAYER_NAMES = ['block1_conv1', 'block2_conv1']\n",
        "VGG_LOSS_LAYER_WEIGHTS = [1.0, 0.1]\n",
        "VGG_LOSS_WEIGHT = 0.001\n",
        "\n",
        "def vgg_layers(layer_names):\n",
        "  \"\"\" Creates a vgg model that returns a list of intermediate output values.\"\"\"\n",
        "  # Load our model. Load pretrained VGG, trained on imagenet data\n",
        "  vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')\n",
        "  vgg.trainable = False\n",
        "  outputs = [vgg.get_layer(name).output for name in layer_names]\n",
        "  model = tf.keras.Model([vgg.input], outputs)\n",
        "  return model\n",
        "\n",
        "\n",
        "vgg_extractor = vgg_layers(VGG_LOSS_LAYER_NAMES)\n",
        "\n",
        "# ==============================================================================\n",
        "# Total loss\n",
        "def network_loss(output, target):\n",
        "  # L1 loss\n",
        "  l1_loss = tf.reduce_mean(tf.abs(target - output))\n",
        "  # VGG loss\n",
        "  vgg_output = vgg_extractor((output*0.5+0.5)*255)\n",
        "  vgg_target = vgg_extractor((target*0.5+0.5)*255)\n",
        "  vgg_loss = 0\n",
        "  for l in range(len(VGG_LOSS_LAYER_WEIGHTS)):\n",
        "    layer_loss = tf.reduce_mean(tf.square(vgg_target[l] - vgg_output[l]))\n",
        "    vgg_loss += VGG_LOSS_LAYER_WEIGHTS[l]*layer_loss\n",
        "  # Final loss\n",
        "  total_loss = l1_loss + VGG_LOSS_WEIGHT*vgg_loss\n",
        "  return l1_loss, vgg_loss, total_loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "SViS-K6OUgCC"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(input_voxels, input_rendering, input_light, target, epoch):\n",
        "  with tf.GradientTape() as tape:\n",
        "    network_output = nvr_plus_model([input_voxels, \n",
        "                                     input_rendering, \n",
        "                                     input_light],\n",
        "                                     training=True)\n",
        "    l1_loss, vgg_loss, total_loss = network_loss(network_output, target)\n",
        "  network_gradients = tape.gradient(total_loss,\n",
        "                                    nvr_plus_model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(network_gradients,\n",
        "                                nvr_plus_model.trainable_variables))\n",
        "\n",
        "  with summary_writer.as_default():\n",
        "    tf.summary.scalar('total_loss', total_loss, step=epoch)\n",
        "    tf.summary.scalar('l1_loss', l1_loss, step=epoch)\n",
        "    tf.summary.scalar('vgg_loss', vgg_loss, step=epoch)\n",
        "    tf.summary.image('Vox_rendering', \n",
        "                     input_rendering*0.5+0.5, \n",
        "                     step=epoch, \n",
        "                     max_outputs=4)\n",
        "    tf.summary.image('Prediction', \n",
        "                     network_output*0.5+0.5, \n",
        "                     step=epoch, \n",
        "                     max_outputs=4)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "FO3EHzQ66yxD"
      },
      "outputs": [],
      "source": [
        "def training_loop(train_ds, epochs):\n",
        "  for epoch in range(epochs):\n",
        "    start = time.time()\n",
        "\n",
        "    # Train\n",
        "    for n, (_, voxels, vox_rendering, target, light) in train_ds.enumerate():\n",
        "      print('.', end='')\n",
        "      if (n+1) % 100 == 0:\n",
        "        print()\n",
        "      train_step(voxels, vox_rendering, light, target, epoch)\n",
        "    print()\n",
        "\n",
        "    # saving (checkpoint) the model every 20 epochs\n",
        "    if (epoch + 1) % 20 == 0:\n",
        "      checkpoint.save(file_prefix = checkpoint_prefix)\n",
        "\n",
        "    print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
        "                                                        time.time()-start))\n",
        "  checkpoint.save(file_prefix = checkpoint_prefix)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yuj29vLY6zPg"
      },
      "outputs": [],
      "source": [
        "NUMBER_OF_EPOCHS = 100\n",
        "training_loop(dataset, NUMBER_OF_EPOCHS)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "nvr plus training.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
